import os
import random


# Dataset root scanned by list_file(): each top-level sub-directory is taken
# as a speaker id, and files two directory levels below it are collected.
ROOT_PATH = r'D:\dataset\wav'

# Where the generated trial list ("label wave0 wave1" per line) is written.
OUT_FILENAME = '../lists/test_list.txt'

# Number of trials written = PERCENT % of the number of files found.
PERCENT = 80

# Upper bound on how many top-level entries (speakers) list_file() scans.
nspeaker = 40

# Running total of collected files; incremented by list_file().
n = 0


def list_file(parent_path, level, id, file_dict, file_list):
    """Recursively collect files under ``parent_path``, grouped by speaker id.

    Expected layout: ``ROOT_PATH/<speaker_id>/<subdir>/<file>``.

    * level 0 (the root): each sub-directory is a speaker id. An empty list is
      registered as ``file_dict[id]`` and filled by the recursive call. At most
      ``nspeaker`` speaker directories are taken.
    * level 1 (a speaker directory): each sub-directory is descended into.
    * level 2 and deeper: each regular file is appended to ``file_list`` as a
      path relative to ``ROOT_PATH``, and the global counter ``n`` is
      incremented. Sub-directories at this depth are ignored.

    Args:
        parent_path: directory to scan.
        level: depth of ``parent_path`` below ``ROOT_PATH`` (0 for the root).
        id: speaker id owning the current subtree (ignored at level 0).
        file_dict: output mapping of speaker id -> list of relative paths.
        file_list: list that collected files are appended to (``None`` at
            level 0).
    """
    global n

    # Sorted so that which speakers are selected, and the file order, do not
    # depend on the filesystem's directory-listing order.
    sub_files = sorted(os.listdir(parent_path))

    num = 0
    for sub_file in sub_files:
        sub_path = os.path.join(parent_path, sub_file)
        if level == 0:
            if not os.path.isdir(sub_path):
                # Stray files at the root are neither speakers nor counted
                # toward the speaker cap.
                continue
            id = sub_file
            print(id)
            file_dict[id] = []
            list_file(sub_path, level + 1, id, file_dict, file_dict[id])
            num += 1
            # The cap limits speakers only; deeper levels must not silently
            # drop sessions or utterances.
            if num >= nspeaker:
                break
        elif level == 1:
            if os.path.isdir(sub_path):
                list_file(sub_path, level + 1, id, file_dict, file_list)
        elif os.path.isfile(sub_path):
            # relpath (rather than slicing by len(ROOT_PATH)) stays correct
            # even if ROOT_PATH ends with a path separator.
            file_list.append(os.path.relpath(sub_path, ROOT_PATH))
            n += 1


def random_exp(begin, end, exp):
    """Return a random integer in ``[begin, end]`` (inclusive) other than ``exp``.

    Values are drawn with ``random.randint`` and redrawn until one differs
    from ``exp``.

    Raises:
        ValueError: if ``begin > end`` (raised by ``random.randint``), or if
            ``exp`` is the only value in the range, so no valid result exists
            and the redraw loop could never terminate.
    """
    if begin == end == exp:
        raise ValueError(
            'random_exp: no value in [{}, {}] other than {}'.format(begin, end, exp))

    i = random.randint(begin, end)
    while i == exp:
        i = random.randint(begin, end)

    return i


if __name__ == '__main__':
    # Build {speaker_id: [relative file paths]} from the dataset tree.
    file_dict = {}
    list_file(ROOT_PATH, 0, '', file_dict, None)

    # Only speakers that actually have files can appear in a trial. A
    # same-speaker trial also needs two distinct files from that speaker;
    # otherwise random_exp() could never find a second index.
    keys = [k for k in file_dict if file_dict[k]]
    pos_keys = [k for k in keys if len(file_dict[k]) >= 2]

    # Number of trials to emit: PERCENT % of all files found.
    sel_num = int(n * PERCENT / 100)

    if sel_num > 0 and not pos_keys:
        raise SystemExit('no speaker has at least two files; '
                         'cannot build same-speaker trials')
    if sel_num > 1 and len(keys) < 2:
        raise SystemExit('fewer than two speakers with files; '
                         'cannot build different-speaker trials')

    out_dir = os.path.dirname(OUT_FILENAME)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 'with' guarantees the list is flushed and closed even on error.
    with open(OUT_FILENAME, 'w') as fs:
        for i in range(sel_num):
            if i % 2 == 0:
                # Label 1: two different files of the same speaker.
                file_list = file_dict[random.choice(pos_keys)]
                idx0 = random.randint(0, len(file_list) - 1)
                idx1 = random_exp(0, len(file_list) - 1, idx0)
                line = '{} {} {}\n'.format(1, file_list[idx0], file_list[idx1])
            else:
                # Label 0: one file from each of two different speakers.
                spk0 = random.randint(0, len(keys) - 1)
                spk1 = random_exp(0, len(keys) - 1, spk0)
                wave0 = random.choice(file_dict[keys[spk0]])
                wave1 = random.choice(file_dict[keys[spk1]])
                line = '{} {} {}\n'.format(0, wave0, wave1)
            fs.write(line)
